import yaml
from mltools import utils
from ultralytics import YOLO


def _setting(section, key, default):
    """Return ``section[key]``, or *default* when the key is absent or null."""
    value = section.get(key)
    return default if value is None else value


# Train the model.
@utils.get_gpu
def train(device):
    """Train a YOLO model using the ``train`` section of ``config.yaml``.

    ``config.yaml`` is opened relative to the current working directory and
    the parsed config is printed before training starts.

    Required keys under ``train``:
        weight_path: weights/model spec passed to ``YOLO``.
        data_path: dataset spec passed as ``data=`` to ``model.train``.

    Optional keys under ``train`` (a missing key or a null value uses the
    default):
        device: overrides the *device* argument.
        epochs: default 100.
        imgsz: default 1024.
        batch: default -1.
        amp: default False.

    Args:
        device: device passed to ``model.train`` unless the config sets one.
            NOTE(review): the ``__main__`` guard calls ``train()`` with no
            arguments, so this is presumably injected by ``utils.get_gpu`` —
            confirm against mltools.
    """
    with open("config.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    print(config)
    train_cfg = config["train"]
    weight_path = train_cfg["weight_path"]
    data_path = train_cfg["data_path"]
    # The config may pin a device; otherwise keep the one we were given.
    device = _setting(train_cfg, "device", device)
    model = YOLO(weight_path)
    model.train(
        data=data_path,
        epochs=_setting(train_cfg, "epochs", 100),
        imgsz=_setting(train_cfg, "imgsz", 1024),
        batch=_setting(train_cfg, "batch", -1),
        amp=_setting(train_cfg, "amp", False),
        device=device,
    )


if __name__ == "__main__":
    # Script entry point: run training with settings from config.yaml.
    # No device is passed here; NOTE(review): presumably @utils.get_gpu
    # supplies the `device` argument — confirm against mltools.
    train()
